In [ ]:
import json
import pandas as pd
from sklearn.metrics import silhouette_score,calinski_harabasz_score,pairwise_distances
from Kernels.src.Analysis.Clustering import *
from sklearn.cluster import SpectralClustering
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import normalized_mutual_info_score, adjusted_mutual_info_score
from sklearn.metrics import adjusted_rand_score
import itertools
import numpy as np
import matplotlib.pyplot as plt

from statsmodels.graphics.mosaicplot import mosaic
import lifelines
import statsmodels.api as sm
import itertools
from scipy import stats
import plotly.graph_objects as go
from matplotlib.offsetbox import AnchoredText
import matplotlib.font_manager as fm
import seaborn as sns
import matplotlib.colorbar as cb

Set plot params¶

In [ ]:
# Path to the Arial font used for all figures (environment-specific).
fontPath = "/CTGlab/home/danilo/.fonts/arial.ttf"

# ---- seaborn (set first: set_style resets several rcParams) ----
sns.set_context("paper")
sns.set_style("whitegrid")

# ---- matplotlib ----
# Register the custom font and make it the default family.
prop = fm.FontProperties(fname=fontPath)
fm.fontManager.addfont(fontPath)
plt.rcParams['font.family'] = prop.get_name()

# Text sizes, legend spacing, and TeX handling.
plt.rcParams.update({
    'font.size': 14,
    'axes.labelsize': 14,
    'axes.titlesize': 15,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.labelspacing': 1.5,
    'legend.fontsize': 14,
    'text.usetex': False,  # disable the LaTeX interpreter
})

# Lines and axes frame.
plt.rcParams.update({
    'lines.linewidth': 1,
    'axes.linewidth': 1,
    'axes.edgecolor': '#222222',
    'axes.titley': 1.05,  # title y-position in axes-relative coordinates
    'axes.titlepad': 0,
})

 

Load and Parse data¶

In [ ]:
# Load the UMAP-reduced METABRIC data and subsample 100 rows reproducibly.
data_input = pd.read_csv("Input_data/UMAP_METABRIC_4cmp_IC10_full.csv", sep=",", index_col=0)
data_input = data_input.sample(n=100, axis=0, random_state=42)

# Empty frame indexed by the sampled patients; cluster assignments go here.
df_new_clust = pd.DataFrame(index=data_input.index)
In [ ]:
# Classical-kernel feature selection: one UMAP component per qubit.
n_qubits = 4
features = [f'Component_{i}' for i in range(1, n_qubits + 1)]
labels = 'IntClustMemb'
print(features)

######## SAMPLING ###################################
y_train = data_input[labels].to_numpy()  # IC10 cluster membership as target
X_train = data_input[features]
['Component_1', 'Component_2', 'Component_3', 'Component_4']

Sankey¶

In [ ]:
#Load original patient index from data and the METABRIC clinical metadata
#Load clustering assignments produced by the 2000-sample simulated run
print('#### CASE 2000 samples simulated ####')
df_clust=pd.read_csv('/CTGlab/home/valeria/Quantum-Machine-Learning-for-Expression-Data/Results/Unsupervised_2000_umap/clustering_1980_clusters.csv',index_col=0)
# Harmonise kernel naming: "_Z_full" column suffix and 'Z_full' cell values -> "Z"
df_clust.columns = df_clust.columns.str.replace("_Z_full", "_Z")
df_clust.replace('Z_full','Z',inplace=True)
data_input = pd.read_csv("Input_data/UMAP_METABRIC_4cmp_IC10_full.csv", sep = ",",index_col=0)
patient_index=data_input.index
#set index to df_clust to patient index
# NOTE(review): assumes df_clust rows are in the same order as data_input —
# confirm the clustering file was written from this exact input.
df_clust.index=patient_index

#Load metadata (skiprows drops the 4 extra header lines of the export)
metabric = pd.read_csv('/CTGlab/data/brca_metabric/data_clinical_patient.txt', sep='\t', skiprows=[0,1,2,3])
#Remove patients with missing OS_STATUS and merge with data
metabricSubset = metabric.loc[~metabric['OS_STATUS'].isna()]
# Left-merge clinical rows with cluster assignments on patient id, then drop
# patients without any cluster assignment.
quantumDf = pd.merge(metabricSubset, df_clust, how='left', right_on=[df_clust.index], left_on=['PATIENT_ID']).dropna(subset=df_clust.columns)
# Survival event indicator: 1 only for "Died of Disease".
# NOTE(review): "Died of Other Causes" is counted as censored — confirm intended.
quantumDf['OS_STATUS_censored'] = quantumDf['VITAL_STATUS'].apply(lambda x: 1 if x == "Died of Disease" else 0) 
quantumDf['OS_STATUS'] = quantumDf['OS_STATUS'].apply(lambda x: int(x.split(':')[0]))  # e.g. "1:DECEASED" -> 1
#### CASE 2000 samples simulated ####
In [ ]:
# Collect every clustering-assignment column produced by the pipeline.
cluster_col = [name for name in quantumDf.columns if 'Cluster' in name]
In [ ]:
# Prefix each category value with its column name so Sankey node labels are unique.
quantumDf_s = quantumDf.copy()
for column in ['INTCLUST', 'CLAUDIN_SUBTYPE'] + cluster_col:
    quantumDf_s[column] = quantumDf_s[column].apply(lambda value, col=column: f"{col}_{value}")
In [ ]:
def plotSankey(df: pd.DataFrame, cols: list, names=None, linkVar=None, valueToHighlight=None):
    """
    Plot a Sankey diagram of how samples flow between categorical columns.

    Parameters:
    - df (pd.DataFrame): The dataframe containing the data. Values in `cols`
      must be prefixed with their column name ("<col>_<value>"), as produced
      by the labelling cells above.
    - cols (list): Column names to plot, left to right.
    - names (list, optional): Display names for the column annotations;
      defaults to `cols`.
    - linkVar (str): The name of the column to use as a link variable (optional).
    - valueToHighlight: The value (or list/tuple/set of values) in the link
      variable to highlight in orange (optional).

    Returns:
    - fig: The generated plotly figure object.
    """
    # Unique labelled categories, grouped per column.
    colsToPlot = [list(set(df[col])) for col in cols]

    # Map every category label to a global node index for plotly.
    toenumerate = [label for column in colsToPlot for label in column]
    mask = {label: idx for idx, label in enumerate(toenumerate)}

    # All (source, target) category pairs between adjacent columns.
    combinations = []
    for left, right in zip(colsToPlot, colsToPlot[1:]):
        combinations += list(itertools.product(left, right))

    def _is_highlighted(val):
        # Accept either a scalar or a collection of values; call sites pass
        # a list (e.g. ['High']), which a plain equality test would never match.
        if isinstance(valueToHighlight, (list, tuple, set)):
            return val in valueToHighlight
        return val == valueToHighlight

    sources, targets, values, colors = [], [], [], []
    if linkVar is not None:
        linkVarValues = df[linkVar].unique()
    for left_label, right_label in combinations:
        source = mask[left_label]
        target = mask[right_label]
        # Recover the original column names from the "<col>_<value>" labels.
        column1 = '_'.join(left_label.split('_')[:-1])
        column2 = '_'.join(right_label.split('_')[:-1])

        if linkVar is not None:
            # One link per linkVar value so highlighted flows stand out.
            # NOTE(review): rows where df[linkVar] is NaN are not counted in
            # any flow — confirm this is intended.
            for val in linkVarValues:
                flow = df.loc[(df[column1] == left_label)
                              & (df[column2] == right_label)
                              & (df[linkVar] == val)].shape[0]
                colors.append("orange" if _is_highlighted(val) else "lightgrey")
                sources.append(source)
                targets.append(target)
                values.append(flow)
        else:
            flow = df.loc[(df[column1] == left_label) & (df[column2] == right_label)].shape[0]
            colors.append('lightgrey')
            sources.append(source)
            targets.append(target)
            values.append(flow)

    fig = go.Figure(data=[go.Sankey(
        node=dict(
            pad=15,
            thickness=20,
            line=dict(color="black", width=0.5),
            # Show only the raw category value, without the column prefix.
            label=[f"{x.split('_')[-1]}" for x in toenumerate]
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            color=colors,
            label=[f"{x.split('_')[-1]}" for x in toenumerate]
        ))])

    if names is None:
        names = cols
    # Annotate each column of nodes with its display name.
    for x_coordinate, column_name in enumerate(names):
        fig.add_annotation(
            x=x_coordinate,
            y=1.05,
            xref="x",
            yref="paper",
            text=column_name,
            showarrow=False,
            font=dict(size=16, color="black"),
            align="center",
        )
    fig.update_layout(
        # Plotly titles need "<br>" for line breaks; "\n" is not rendered.
        title='Sample Assignments in Clustering Methods' if valueToHighlight is None
        else f"Sample Assignments in Clustering Methods<br>Highlighted {valueToHighlight}",
        plot_bgcolor='white',
        paper_bgcolor='white',
        font_size=14,
        hovermode='x',
        autosize=False,
        width=1000, height=600,
        xaxis={
            'showgrid': False,  # thin lines in the background
            'zeroline': False,  # thick line at x=0
            'visible': False,   # numbers below
        },
        yaxis={
            'showgrid': False,
            'zeroline': False,
            'visible': False,
        }
    )
    fig.show()
    return fig

Simulated 2000¶

Let's check the highest-scoring clusters for the simulated 2000 case:

  • Best ones according to the silhouette score (cut-off=0.3)

Get meta data columns to plot

In [ ]:
quantumDf_s.columns[:24]  # peek at the clinical metadata columns available for plotting
Out[ ]:
Index(['PATIENT_ID', 'LYMPH_NODES_EXAMINED_POSITIVE', 'NPI', 'CELLULARITY',
       'CHEMOTHERAPY', 'COHORT', 'ER_IHC', 'HER2_SNP6', 'HORMONE_THERAPY',
       'INFERRED_MENOPAUSAL_STATE', 'SEX', 'INTCLUST', 'AGE_AT_DIAGNOSIS',
       'OS_MONTHS', 'OS_STATUS', 'CLAUDIN_SUBTYPE', 'THREEGENE',
       'VITAL_STATUS', 'LATERALITY', 'RADIO_THERAPY', 'HISTOLOGICAL_SUBTYPE',
       'BREAST_SURGERY', 'RFS_STATUS', 'RFS_MONTHS'],
      dtype='object')
In [ ]:
# Metadata stratifications to compare against the clustering results.
meta_plot=['INTCLUST','CLAUDIN_SUBTYPE','ER_IHC', 'HER2_SNP6','HORMONE_THERAPY','THREEGENE','HISTOLOGICAL_SUBTYPE']
In [ ]:
# Rebuild the prefixed frame, this time covering all meta_plot columns too.
quantumDf_s = quantumDf.copy()
for column in meta_plot + cluster_col:
    quantumDf_s[column] = quantumDf_s[column].apply(lambda value, col=column: f"{col}_{value}")

Filter configurations

In [ ]:
#Load clustering results: per-configuration quality metrics (silhouette, CHI, DI, ...)
df_2000=pd.read_csv('/CTGlab/home/valeria/Quantum-Machine-Learning-for-Expression-Data/Results/Unsupervised_2000_umap/clustering_1980_opt_k_reviewed.csv',index_col=0)
In [ ]:
# Harmonise kernel naming with the other tables: 'Z_full' -> 'Z' (in place).
df_2000.replace('Z_full','Z',inplace=True)
In [ ]:
#Drop duplicated rbf cases due to bandwidth: keep one representative
#(Bandwidth == 1) per rbf configuration; non-rbf rows are kept as-is.
df_2000=df_2000[(df_2000['ftmap'] != 'rbf') | (df_2000['Bandwidth'] == 1)]

Let's start with the silhouette.

In [ ]:
# Rank configurations by silhouette, best first (sorts df_2000 in place).
df_2000.sort_values(by='silhouette',ascending=False,inplace=True)
In [ ]:
#Get only the 10 best configurations with silhouette > 0.3
df_2000[df_2000['silhouette']>0.3].sort_values(by='silhouette',ascending=False).head(10)
Out[ ]:
ftmap K Bandwidth s geom_distance concentration silhouette Score_cluster CHI DI v_intra v_inter N_samples
135 ZZ_linear 2 0.125 198446.008917 11.510186 0.066276 0.627849 0.624853 271.129995 0.833597 0.052648 0.000630 1980
28 rbf 3 1.000 126427.404049 2.459801 0.028149 0.568687 0.612915 1002.217980 0.000669 0.008163 0.030924 1980
27 rbf 2 1.000 126427.404049 2.459801 0.028149 0.541753 0.548029 1323.383994 0.000420 0.017752 0.024397 1980
136 ZZ_linear 3 0.125 198446.008917 11.510186 0.066276 0.525904 0.561904 774.054846 0.043874 0.022971 0.055678 1980
29 rbf 4 1.000 126427.404049 2.459801 0.028149 0.427324 0.615735 1054.539243 0.000872 0.005637 0.030072 1980
30 rbf 5 1.000 126427.404049 2.459801 0.028149 0.414847 0.655101 1030.559094 0.000728 0.004759 0.028514 1980
31 rbf 6 1.000 126427.404049 2.459801 0.028149 0.408590 0.687778 949.856227 0.000577 0.004185 0.027981 1980
32 rbf 7 1.000 126427.404049 2.459801 0.028149 0.390088 0.717164 961.774043 0.000453 0.003064 0.027807 1980
152 ZZ_linear 10 0.250 146583.001305 5.823200 0.080595 0.389346 0.600019 649.065192 0.036875 0.023439 0.050366 1980
151 ZZ_linear 9 0.250 146583.001305 5.823200 0.080595 0.373357 0.562661 693.376521 0.035284 0.030125 0.049718 1980

Considering most of these are rbf, and thus not very useful for us, we can proceed and plot only the rbf ones to see cluster consistency between micro and macro clusters, and then select one as a reference to compare with the quantum results.

In [ ]:
cols_to_plot_rbf = []
cols_name_rbf = []

# Top 10 rbf configurations above the silhouette cut-off: build the cluster
# column names and short silhouette labels for the Sankey annotations.
for row in df_2000[(df_2000['silhouette'] > 0.3) & (df_2000.ftmap == 'rbf')].iloc[:10].itertuples():
    cols_name_rbf.append(f"S:{row.silhouette:.2f}")
    cols_to_plot_rbf.append(f"Cluster_{row.ftmap}_{row.K}_{row.Bandwidth:g}")
In [ ]:
# Sankey across the best rbf configurations; links split by CELLULARITY,
# with 'High' requested as the highlighted value.
fig=plotSankey(quantumDf_s, cols=cols_to_plot_rbf, names=cols_name_rbf, valueToHighlight=['High'], linkVar='CELLULARITY')

We can see that overall there are two macroclusters (+1 small other), one of which then tends to stratify (cluster 0.0) while the other seems more consistent. We take K=7 as reference.

In [ ]:
# Reference classical clustering: rbf kernel, K = 7, Bandwidth = 1.
rbf_clusters_s='Cluster_rbf_7_1'

Let's move on and take the best 5 quantum clusters.

In [ ]:
# Best 5 non-rbf (quantum) configurations above the silhouette cut-off.
df_2000[(df_2000['silhouette']>0.3) & (df_2000.ftmap!='rbf')].iloc[:5]
Out[ ]:
ftmap K Bandwidth s geom_distance concentration silhouette Score_cluster CHI DI v_intra v_inter N_samples
135 ZZ_linear 2 0.125 198446.008917 11.510186 0.066276 0.627849 0.624853 271.129995 0.833597 0.052648 0.000630 1980
136 ZZ_linear 3 0.125 198446.008917 11.510186 0.066276 0.525904 0.561904 774.054846 0.043874 0.022971 0.055678 1980
152 ZZ_linear 10 0.250 146583.001305 5.823200 0.080595 0.389346 0.600019 649.065192 0.036875 0.023439 0.050366 1980
151 ZZ_linear 9 0.250 146583.001305 5.823200 0.080595 0.373357 0.562661 693.376521 0.035284 0.030125 0.049718 1980
111 Z 5 0.500 139813.611325 2.923790 0.055403 0.365701 0.502482 977.285005 0.044605 0.018611 0.041749 1980

Since all the best clusters were obtained with the ZZ linear feature map, we omit this label in further plots to improve readability.

In [ ]:
cols_to_plot_q = []
cols_name_q = []

# Best 5 quantum (non-rbf) configurations above the silhouette cut-off:
# cluster column names plus bandwidth/silhouette labels for the annotations.
for row in df_2000[(df_2000['silhouette'] > 0.3) & (df_2000.ftmap != 'rbf')].iloc[:5].itertuples():
    cols_name_q.append(f"b:{row.Bandwidth},S:{row.silhouette:.3f}")
    cols_to_plot_q.append(f"Cluster_{row.ftmap}_{row.K}_{row.Bandwidth:g}")
In [ ]:
# Sankey across the best quantum configurations; links split by CELLULARITY.
fig=plotSankey(quantumDf_s, cols=cols_to_plot_q, names=cols_name_q, valueToHighlight=['High'], linkVar='CELLULARITY')

As in the other cases, we see once again how the stratification with lower K represents a macrostructure in our data and tends to have a higher value of s due to the bias of the silhouette score. There also seems to be good consistency between the finer stratifications (K=10, 9, 8), and considering they came from the same kernel (ftmap=ZZ_linear, b=0.25) we decide to keep K=10 as a reference since it yields the best silhouette.

In [ ]:
# Reference quantum clustering: ZZ_linear feature map, b = 0.25, K = 10.
q_cluster_s='Cluster_ZZ_linear_10_0.25'

Now we can put together quantum, rbf and known stratifications (INTCLUST and CLAUDIN) and look at how they relate:

  • quantum vs classical
  • quantum + all meta
In [ ]:
print(q_cluster_s)  # sanity-check which quantum cluster column was selected
Cluster_ZZ_linear_10_0.25
In [ ]:
# Sankey: classical rbf reference vs. the selected quantum clustering.
fig=plotSankey(quantumDf_s, cols=[rbf_clusters_s,q_cluster_s], names=['rbf','ZZ_linear_0.25 best_s'], valueToHighlight=['High'], linkVar='CELLULARITY')
In [ ]:
# Agreement between the classical and quantum reference clusterings.
print('AMI between rbf and ZZ_linear_0.25',adjusted_mutual_info_score(quantumDf_s[rbf_clusters_s],quantumDf_s[q_cluster_s]))
AMI between rbf and ZZ_linear_0.25 0.5680031848012492
In [ ]:
# AMI of each metadata stratification against the selected quantum clustering.
for meta in meta_plot:
    ami = adjusted_mutual_info_score(quantumDf_s[meta], quantumDf_s[q_cluster_s])
    print(f'AMI between {meta} and ZZ_linear_0.25', ami)
    
AMI between INTCLUST and ZZ_linear_0.25 -0.00026901173333492847
AMI between CLAUDIN_SUBTYPE and ZZ_linear_0.25 -9.660250617278738e-05
AMI between ER_IHC and ZZ_linear_0.25 -0.0019880083791308876
AMI between HER2_SNP6 and ZZ_linear_0.25 0.000838793203366204
AMI between HORMONE_THERAPY and ZZ_linear_0.25 -0.0006368616323537453
AMI between THREEGENE and ZZ_linear_0.25 2.5483342799545463e-05
AMI between HISTOLOGICAL_SUBTYPE and ZZ_linear_0.25 0.0006967353256035525

Quantum and BC stratifications¶

In [ ]:
# For each metadata stratification: report the AMI and draw the Sankey
# against the selected quantum clustering.
for meta in meta_plot:
    ami = adjusted_mutual_info_score(quantumDf_s[meta], quantumDf_s[q_cluster_s])
    print(f'AMI between {meta} and ZZ_linear_0.25', ami)
    fig = plotSankey(quantumDf_s, cols=[q_cluster_s, meta], valueToHighlight=['High'], linkVar='CELLULARITY')
    
AMI between INTCLUST and ZZ_linear_0.25 -0.00026901173333492847
AMI between CLAUDIN_SUBTYPE and ZZ_linear_0.25 -9.660250617278738e-05
AMI between ER_IHC and ZZ_linear_0.25 -0.0019880083791308876
AMI between HER2_SNP6 and ZZ_linear_0.25 0.000838793203366204
AMI between HORMONE_THERAPY and ZZ_linear_0.25 -0.0006368616323537453
AMI between THREEGENE and ZZ_linear_0.25 2.5483342799545463e-05
AMI between HISTOLOGICAL_SUBTYPE and ZZ_linear_0.25 0.0006967353256035525